import  numpy as np

# Scratch check of how one replay-memory row is sliced into its parts.
#
# Layout of one memory row along the last axis (derived from the slices below):
#   [0, STATE_DIM)                               -> bs_
#   [STATE_DIM, 2*STATE_DIM)                     -> bs
#   [2*STATE_DIM, 2*STATE_DIM + ACTION_DIM)      -> ba
#   [2*STATE_DIM + ACTION_DIM, ROW_DIM)          -> br
# NOTE(review): the first block is named ``bs_`` and the second ``bs``; a
# trailing underscore usually denotes the *next* state. Confirm this ordering
# matches how rows are actually written into ``memory``.
MEMORY_CAPACITY = 100
N_STEPS = 5  # size of axis 1; only step 0 is unpacked below. TODO confirm meaning
STATE_DIM = 40
ACTION_DIM = 8
REWARD_DIM = 1
BATCH_SIZE = 10
ROW_DIM = 2 * STATE_DIM + ACTION_DIM + REWARD_DIM  # 89 with the values above

memory = np.zeros((MEMORY_CAPACITY, N_STEPS, ROW_DIM), dtype=np.float32)
# np.random.choice samples WITH replacement by default, so a batch may repeat rows.
indices = np.random.choice(MEMORY_CAPACITY, size=BATCH_SIZE)

BATCH = memory[indices, :, :]  # (BATCH_SIZE, N_STEPS, ROW_DIM)
batch = BATCH[:, 0, :]         # (BATCH_SIZE, ROW_DIM): step 0 of each sampled row

# Compute every slice boundary from the single layout above so the parts stay
# contiguous and non-overlapping if any dimension changes.
_action_start = 2 * STATE_DIM
_reward_start = _action_start + ACTION_DIM
bs_ = batch[:, :STATE_DIM]
bs = batch[:, STATE_DIM:_action_start]
ba = batch[:, _action_start:_reward_start]
br = batch[:, _reward_start:_reward_start + REWARD_DIM]

S_A_ = BATCH[:, :, STATE_DIM:]  # every column after the first state block, all steps

for _arr in (BATCH, batch, bs_, bs, ba, br, S_A_):
    print(_arr.shape)

# ``ba[0:]`` is a full view of ``ba``; this just repeats ba's shape.
print(ba[0:].shape)